package org.mockserver.server.unification; import io.netty.buffer.ByteBuf; import io.netty.channel.ChannelHandler; import io.netty.channel.ChannelHandlerContext; import io.netty.channel.ChannelPipeline; import io.netty.channel.SimpleChannelInboundHandler; import io.netty.handler.codec.http.HttpContentDecompressor; import io.netty.handler.codec.http.HttpObjectAggregator; import io.netty.handler.codec.http.HttpServerCodec; import io.netty.handler.logging.LoggingHandler; import io.netty.handler.ssl.SslHandler; import io.netty.util.AttributeKey; import org.mockserver.socket.SSLFactory; import org.slf4j.Logger; import org.slf4j.LoggerFactory; /** * @author jamesdbloom */ @ChannelHandler.Sharable public abstract class PortUnificationHandler extends SimpleChannelInboundHandler<ByteBuf> { public static final AttributeKey<Boolean> SSL_ENABLED = AttributeKey.valueOf("SSL_ENABLED"); private final Logger logger = LoggerFactory.getLogger(this.getClass()); public PortUnificationHandler() { super(false); } @Override protected void channelRead0(ChannelHandlerContext ctx, ByteBuf msg) throws Exception { // Will use the first five bytes to detect a protocol. if (msg.readableBytes() < 3) { return; } if (isSsl(msg)) { enableSsl(ctx, msg); } else if (isHttp(msg)) { switchToHttp(ctx, msg); } else { // Unknown protocol; discard everything and close the connection. msg.clear(); ctx.close(); } if (logger.isTraceEnabled()) { if (ctx.pipeline().get(org.mockserver.logging.LoggingHandler.class) != null) { ctx.pipeline().remove(org.mockserver.logging.LoggingHandler.class); } if (ctx.pipeline().get(SslHandler.class) != null) { ctx.pipeline().addAfter("SslHandler#0", "LoggingHandler#0", new org.mockserver.logging.LoggingHandler(logger)); } else { ctx.pipeline().addFirst(new org.mockserver.logging.LoggingHandler(logger)); } } } private boolean isSsl(ByteBuf buf) { return buf.readableBytes() >= 5 && SslHandler.isEncrypted(buf); } private boolean isHttp(ByteBuf msg) { int letterOne = (int) msg.getUnsignedByte(msg.readerIndex()); int letterTwo = (int) msg.getUnsignedByte(msg.readerIndex() + 1); int letterThree = (int) msg.getUnsignedByte(msg.readerIndex() + 2); return letterOne == 'G' && letterTwo == 'E' && letterThree == 'T' || // GET letterOne == 'P' && letterTwo == 'O' && letterThree == 'S' || // POST letterOne == 'P' && letterTwo == 'U' && letterThree == 'T' || // PUT letterOne == 'H' && letterTwo == 'E' && letterThree == 'A' || // HEAD letterOne == 'O' && letterTwo == 'P' && letterThree == 'T' || // OPTIONS letterOne == 'P' && letterTwo == 'A' && letterThree == 'T' || // PATCH letterOne == 'D' && letterTwo == 'E' && letterThree == 'L' || // DELETE letterOne == 'T' && letterTwo == 'R' && letterThree == 'A' || // TRACE letterOne == 'C' && letterTwo == 'O' && letterThree == 'N'; // CONNECT } private void enableSsl(ChannelHandlerContext ctx, ByteBuf msg) { ChannelPipeline pipeline = ctx.pipeline(); pipeline.addFirst(new SslHandler(SSLFactory.createServerSSLEngine())); ctx.channel().attr(PortUnificationHandler.SSL_ENABLED).set(Boolean.TRUE); // re-unify (with SSL enabled) ctx.pipeline().fireChannelRead(msg); } private void switchToHttp(ChannelHandlerContext ctx, ByteBuf msg) { ChannelPipeline pipeline = ctx.pipeline(); addLastIfNotPresent(pipeline, new HttpServerCodec()); addLastIfNotPresent(pipeline, new HttpContentDecompressor()); addLastIfNotPresent(pipeline, new HttpContentLengthRemover()); addLastIfNotPresent(pipeline, new HttpObjectAggregator(Integer.MAX_VALUE)); if (logger.isDebugEnabled()) { addLastIfNotPresent(pipeline, new LoggingHandler()); } configurePipeline(ctx, pipeline); pipeline.remove(this); // fire message back through pipeline ctx.fireChannelRead(msg); } protected void addLastIfNotPresent(ChannelPipeline pipeline, ChannelHandler channelHandler) { if (pipeline.get(channelHandler.getClass()) == null) { pipeline.addLast(channelHandler); } } protected abstract void configurePipeline(ChannelHandlerContext ctx, ChannelPipeline pipeline); }